{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright &copy; 2015 Ondrej Martinsky, All rights reserved\n",
    "\n",
    "[www.quantandfinancial.com](http://www.quantandfinancial.com)\n",
    "# LU Decomposition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using matplotlib backend: Qt4Agg\n",
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from numpy import array, multiply, dot\n",
    "from scipy.linalg import lu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gaussian Elimination\n",
    "Code for calculating **X** from **M*X=Y** where *M* is either a lower or an upper triangular matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def solve_l(m, y):\n",
    "    \"\"\"Solve m*x = y for x by forward substitution.\n",
    "\n",
    "    m -- square lower triangular matrix\n",
    "    y -- right-hand side vector\n",
    "    Returns the solution vector x as a float array.\n",
    "    \"\"\"\n",
    "    assert (m==tril(m)).all()        # matrix must be lower triangular\n",
    "    assert (m.shape[0]==m.shape[1])  # matrix must be square\n",
    "    N=m.shape[0]\n",
    "    x=zeros(N)                       # solution vector\n",
    "    for r in range(N):\n",
    "        s = dot(m[r,:r], x[:r])      # contribution of the unknowns solved so far\n",
    "        x[r] = (y[r]-s) / m[r,r]\n",
    "    assert allclose(dot(m,x), y)     # check solution\n",
    "    return x\n",
    "\n",
    "def solve_u(m, y):\n",
    "    \"\"\"Solve m*x = y for x where m is a square upper triangular matrix.\n",
    "\n",
    "    Flipping m left-right and up-down turns it into a lower triangular\n",
    "    matrix, so the system is solved with solve_l on the reversed vectors.\n",
    "    \"\"\"\n",
    "    assert (m==triu(m)).all()    # matrix must be upper triangular\n",
    "    m2 = fliplr(flipud(m))       # upper triangular becomes lower triangular\n",
    "    y2 = y[::-1]                 # reverse right-hand side to match the flipped rows\n",
    "    x2 = solve_l(m2, y2)\n",
    "    x = x2[::-1]                 # undo the reversal of the unknowns\n",
    "    assert allclose(dot(m,x), y) # check solution\n",
    "    return x\n",
    "\n",
    "def solve(m, y):\n",
    "    \"\"\"Solve m*x = y for x where m is lower or upper triangular.\n",
    "\n",
    "    Raises ValueError when m is neither; without this check such a matrix\n",
    "    would bounce between solve and solve_u until the recursion limit is hit.\n",
    "    \"\"\"\n",
    "    if (m==tril(m)).all():\n",
    "        return solve_l(m,y)\n",
    "    if (m==triu(m)).all():\n",
    "        return solve_u(m,y)\n",
    "    raise ValueError('solve() requires a lower or upper triangular matrix')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solving using L U decomposition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2 4 1]\n"
     ]
    }
   ],
   "source": [
    "# Unknowns\n",
    "x_org = array([2, 4, 1])\n",
    "print(x_org)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 2 -1  1]\n",
      " [ 3  3  9]\n",
      " [ 3  3  5]]\n"
     ]
    }
   ],
   "source": [
    "# Coefficients\n",
    "m = array([[2,-1,1],[3,3,9],[3,3,5]])\n",
    "print(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1 27 23]\n"
     ]
    }
   ],
   "source": [
    "# Results\n",
    "y = dot(m,x_org)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Note: matrix dot-product is not commutative, but is associative\n",
    "p, l, u = lu(m, permute_l=False)\n",
    "pl, u = lu(m, permute_l=True)\n",
    "# Compare floating-point products with allclose: exact equality can fail on rounding error\n",
    "assert allclose(dot(p,l), pl)\n",
    "assert allclose(dot(pl,u), m)\n",
    "# The inverse of a permutation matrix is its transpose\n",
    "assert allclose(pinv(p), p.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 1.          0.          0.        ]\n",
      " [ 0.66666667  1.          0.        ]\n",
      " [ 1.         -0.          1.        ]]\n"
     ]
    }
   ],
   "source": [
    "print(l) # Lower triangular matrix, zero elements above the principal diagonal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 3.  3.  9.]\n",
      " [ 0. -3. -5.]\n",
      " [ 0.  0. -4.]]\n"
     ]
    }
   ],
   "source": [
    "print(u) # Upper triangular matrix, zero elements below the principal diagonal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.  1.  0.]\n",
      " [ 1.  0.  0.]\n",
      " [ 0.  0.  1.]]\n"
     ]
    }
   ],
   "source": [
    "print(p) # Permutation matrix for \"l\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "assert (l*u==multiply(l,u)).all()          # memberwise multiplication\n",
    "assert allclose(m, dot(dot(p,l),u))        # matrix multiplication, M=P*L*U (floats: compare with tolerance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "assert allclose(pinv(p), p.T)  # inverse of a permutation matrix is its transpose\n",
    "#   P*L*U*X = Y\n",
    "#   L*U*X = pinv(P)*Y = P.T*Y\n",
    "#   set Z=U*X\n",
    "#   L*Z = P.T*Y (solve Z)\n",
    "z = solve(l, dot(p.T,y))\n",
    "#   solve X from U*X=Z\n",
    "x = solve(u, z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 2.  4.  1.]\n"
     ]
    }
   ],
   "source": [
    "assert allclose(x_org,x)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
